/*____________________________________________________________________________
		Copyright (C) 2000 Network Associates, Inc.
        All rights reserved.

        $Id: CFileImpDrvNT.cpp,v 1.5 2000/05/13 06:34:21 nryan Exp $
____________________________________________________________________________*/

#include "pgpClassesConfig.h"

#include "CFileImpDrvNT.h"
#include "CUnicodeString.h"

_USING_PGP

_UNNAMED_BEGIN

// Constants

enum
{
	// Byte offset of the PagingIoResource pointer inside the FCB header
	// reached through FILE_OBJECT::FsContext.
	// NOTE(review): presumably the 32-bit x86 layout of the common FCB
	// header - confirm before building for any other target.
	kPagingIoResourceOffsetInFCB = 0x0C,

	// Pause between attempts to swap a paging I/O resource that still has
	// waiters queued on it.
	kReplaceResourceWaitIntervalMs = 50
};

_UNNAMED_END

// Class CFileImpDrvNT member functions

CFileImpDrvNT::CFileImpDrvNT()
	: mFileHandle(NULL), 
	  mPOldPagingIoResource(NULL), 
	  mPPPagingIoResource(NULL)
{
	// Starts with no open handle and no swapped-in paging I/O resource.
}

CFileImpDrvNT::~CFileImpDrvNT()
{
	// Release the handle (and, through Close, any swapped paging I/O
	// resource) if the owner never closed the file.
	if (!IsOpened())
		return;

	Close();
}

CComboError 
CFileImpDrvNT::GetLength(PGPUInt64& length) const
{
	pgpAssert(IsOpened());

	// The file length is the EndOfFile value from FileStandardInformation.
	IO_STATUS_BLOCK				status;
	FILE_STANDARD_INFORMATION	standardInfo;
	CComboError					error;

	error.err = ZwQueryInformationFile(mFileHandle, &status, &standardInfo, 
		sizeof(standardInfo), FileStandardInformation);

	if (error.HaveNonPGPError())
		error.pgpErr = kPGPError_FileOpFailed;

	// The caller's length is written only when the query succeeded.
	if (error.IsntError())
		length = standardInfo.EndOfFile.QuadPart;

	return error;
}

CComboError 
CFileImpDrvNT::GetUniqueFileId(PGPUInt64& fileId) const
{
	// Not supported by the NT driver implementation; fileId is not touched.
	pgpDebugMsg("PGPdisk: Unimplemented function.\n");

	CComboError	unsupported(CComboError::kPGPError, 
		kPGPError_FeatureNotAvailable);
	return unsupported;
}

CComboError 
CFileImpDrvNT::SetIsCompressed(PGPBoolean isCompressed)
{
	// Not supported by the NT driver implementation; the file is unchanged.
	pgpDebugMsg("PGPdisk: Unimplemented function.\n");

	CComboError	unsupported(CComboError::kPGPError, 
		kPGPError_FeatureNotAvailable);
	return unsupported;
}

CComboError 
CFileImpDrvNT::SetLength(PGPUInt64 length)
{
	// Not supported by the NT driver implementation; the file is unchanged.
	pgpDebugMsg("PGPdisk: Unimplemented function.\n");

	CComboError	unsupported(CComboError::kPGPError, 
		kPGPError_FeatureNotAvailable);
	return unsupported;
}

void 
CFileImpDrvNT::Flush()
{
	// Not implemented for the NT driver: apart from the debug message this
	// is a no-op.
	pgpDebugMsg("PGPdisk: Unimplemented function.\n");
	return;
}

// Open 'path' through the NT native API.
//
// 'path' is a Win32-style path. UNC paths (two leading slashes followed by
// at least one more) drop their two leading slashes and receive
// kNTUNCLinkPathPrefix; all other paths receive kNTLinkPathPrefix.
// 'flags' is a combination of CFile::k*Flag values that selects buffering,
// read-only access, sharing and create-if-missing behavior.
//
// On success the object is marked opened and remembers 'flags'. On failure
// no handle is left open and mFileHandle is NULL.
CComboError 
CFileImpDrvNT::Open(const char *path, PGPUInt32 flags)
{
	pgpAssert(!IsOpened());
	pgpAssertStrValid(path);

	CComboError	error;
	error = mPath.Assign(path);

	if (error.IsntError())
	{
		CUnicodeString	uniPath;
		error = uniPath.Status();

		// Prepare the NT object path. UNC paths must be treated specially.
		if (error.IsntError())
		{
			if (IsUNCPath(path))
			{
				error = uniPath.Assign(path + 2);

				if (error.IsntError())
					error = uniPath.Prepend(kNTUNCLinkPathPrefix);
			}
			else
			{
				error = uniPath.Assign(path);

				if (error.IsntError())
					error = uniPath.Prepend(kNTLinkPathPrefix);
			}
		}

		if (error.IsntError())
		{
			IO_STATUS_BLOCK		ioStatus;
			OBJECT_ATTRIBUTES	objAttribs;

			uniPath.RemoveZeroTermination();

			// Initialize object attributes. Where the DDK defines it
			// (Windows 2000 and later), ask for a kernel handle so the handle
			// is usable from any process context and cannot be closed or
			// reused by user-mode code in the current process.
			ULONG	objFlags	= OBJ_CASE_INSENSITIVE;

		#ifdef	OBJ_KERNEL_HANDLE
			objFlags |= OBJ_KERNEL_HANDLE;
		#endif	// OBJ_KERNEL_HANDLE

			InitializeObjectAttributes(&objAttribs, uniPath.Get(), objFlags, 
				NULL, NULL);

			// Set create options.
			PGPUInt32	createOpts	= FILE_NON_DIRECTORY_FILE | 
				FILE_SYNCHRONOUS_IO_NONALERT;

			if (flags & CFile::kNoBufferingFlag)
				createOpts |= FILE_NO_INTERMEDIATE_BUFFERING;

			// Set access flags. SYNCHRONIZE is required because the handle is
			// opened for synchronous I/O.
			ACCESS_MASK	accessMask	= FILE_GENERIC_READ | SYNCHRONIZE;

			if (!(flags & CFile::kReadOnlyFlag))
				accessMask |= FILE_GENERIC_WRITE;

			// Set share flags: share reads unless denied, share writes only
			// on request. This is a bit mask, so start from 0, not NULL.
			PGPUInt32	shareAccess	= 0;

			if (!(flags & CFile::kDenyReadFlag))
				shareAccess |= FILE_SHARE_READ;

			if (flags & CFile::kShareWriteFlag)
				shareAccess |= FILE_SHARE_WRITE;

			// Set the create disposition.
			PGPUInt32	createDisp;

			if (flags & CFile::kCreateIfFlag)
				createDisp = FILE_OPEN_IF;
			else
				createDisp = FILE_OPEN;

			// Open the file handle.
			error.err = ZwCreateFile(&mFileHandle, accessMask, &objAttribs, 
				&ioStatus, NULL, FILE_ATTRIBUTE_NORMAL, shareAccess, 
				createDisp, createOpts, NULL, 0);

			if (error.HaveNonPGPError())
				error.pgpErr = kPGPError_CantOpenFile;

			if (error.IsError())
			{
				// Do not leave a meaningless value behind after a failed
				// create.
				mFileHandle = NULL;
			}
			else
			{
			#if	(_WIN32_WINNT < 0x0500)
				// Fix deadlock possibility (pre-Windows 2000 only).
				if (flags & CFile::kNoBufferingFlag)
					error = ReplacePagingIoResource();
			#endif	// _WIN32_WINNT < 0x0500

				if (error.IsntError())
				{
					mIsOpened	= TRUE;
					mOpenFlags	= flags;
				}
				else
				{
					// Post-open setup failed; do not leak the handle.
					ZwClose(mFileHandle);
					mFileHandle = NULL;
				}
			}
		}
	}

	return error;
}

CComboError 
CFileImpDrvNT::Close()
{
	pgpAssert(IsOpened());

#if	(_WIN32_WINNT < 0x0500)
	// Undo the paging I/O resource swap that Open performs for unbuffered
	// files, while the handle (and thus the file object) is still alive.
	if (mOpenFlags & CFile::kNoBufferingFlag)
		RestorePagingIoResource();
#endif	// _WIN32_WINNT < 0x0500

	CComboError	error;
	error.err = ZwClose(mFileHandle);

	if (error.HaveNonPGPError())
		error.pgpErr = kPGPError_FileOpFailed;

	// Reset the open state only if the handle was actually closed.
	if (error.IsntError())
	{
		mFileHandle	= NULL;
		mOpenFlags	= CFile::kNoFlags;
		mIsOpened	= FALSE;
	}

	return error;
}

CComboError 
CFileImpDrvNT::Delete(const char *path)
{
	// Not supported by the NT driver implementation; nothing is deleted.
	pgpDebugMsg("PGPdisk: Unimplemented function.\n");

	CComboError	unsupported(CComboError::kPGPError, 
		kPGPError_FeatureNotAvailable);
	return unsupported;
}

CComboError 
CFileImpDrvNT::Move(const char *oldPath, const char *newPath)
{
	// Not supported by the NT driver implementation; nothing is moved.
	pgpDebugMsg("PGPdisk: Unimplemented function.\n");

	CComboError	unsupported(CComboError::kPGPError, 
		kPGPError_FeatureNotAvailable);
	return unsupported;
}

CComboError 
CFileImpDrvNT::Read(void *buf, PGPUInt64 pos, PGPUInt32 nBytes) const
{
	pgpAssert(IsOpened());
	pgpAssertAddrValid(buf, VoidAlign);

	// Absolute byte offset to read from.
	LARGE_INTEGER	readOffset;
	readOffset.QuadPart = pos;

	CComboError		error;
	IO_STATUS_BLOCK	ioStatus;

	// NOTE(review): ioStatus.Information (bytes actually read) is not
	// compared with nBytes, so a short read is reported as success.
	error.err = ZwReadFile(mFileHandle, NULL, NULL, NULL, &ioStatus, buf, 
		nBytes, &readOffset, NULL);

	if (error.HaveNonPGPError())
		error.pgpErr = kPGPError_ReadFailed;

	return error;
}

CComboError 
CFileImpDrvNT::Write(const void *buf, PGPUInt64 pos, PGPUInt32 nBytes) const
{
	pgpAssert(!IsReadOnly());
	pgpAssert(IsOpened());
	pgpAssertAddrValid(buf, VoidAlign);

	// Absolute byte offset to write at.
	LARGE_INTEGER	writeOffset;
	writeOffset.QuadPart = pos;

	// ZwWriteFile takes a non-const buffer pointer; the data is only read.
	void	*source	= const_cast<void *>(buf);

	CComboError		error;
	IO_STATUS_BLOCK	ioStatus;

	error.err = ZwWriteFile(mFileHandle, NULL, NULL, NULL, &ioStatus, 
		source, nBytes, &writeOffset, NULL);

	if (error.HaveNonPGPError())
		error.pgpErr = kPGPError_WriteFailed;

	return error;
}

// Swap the PagingIoResource in the FCB header of the just-opened file for
// this object's own resource (mNewPagingIoResource). This is a workaround
// for a deadlock; the long story is in the NTFSD archives for July 1999.
//
// On return, mPOldPagingIoResource is non-NULL only if a swap actually took
// place. RestorePagingIoResource relies on that to decide whether to undo
// it. If the file has no FCB context or no paging I/O resource, nothing is
// swapped and success is returned.
CComboError 
CFileImpDrvNT::ReplacePagingIoResource()
{
	CComboError	error;

	// Start from a clean state, so a failure below never leaves pointers
	// that RestorePagingIoResource could later act on.
	mPPPagingIoResource		= NULL;
	mPOldPagingIoResource	= NULL;

	// Reference the file object.
	PFILE_OBJECT	fileObject	= NULL;

	error.err = ObReferenceObjectByHandle(mFileHandle, FILE_READ_DATA, NULL, 
		KernelMode, reinterpret_cast<void **>(&fileObject), NULL);

	if (error.HaveNonPGPError())
		error.pgpErr = kPGPError_NTDrvObjectOpFailed;

	if (error.IsntError())
	{
		// Without an FCB context there is no resource to swap.
		if (IsntNull(fileObject->FsContext))
		{
			// Locate the PagingIoResource slot in the FCB header.
			PERESOURCE	*pSlot	= reinterpret_cast<PERESOURCE *>(
				reinterpret_cast<PGPByte *>(fileObject->FsContext) + 
				kPagingIoResourceOffsetInFCB);

			// The existing resource may be NULL, in which case there is
			// nothing to replace.
			PERESOURCE	pOldResource	= *pSlot;

			if (IsntNull(pOldResource))
			{
				// Create the replacement resource.
				error.err = ExInitializeResourceLite(&mNewPagingIoResource);

				if (error.HaveNonPGPError())
					error.pgpErr = kPGPError_SyncObjOpFailed;

				// Record the swap only once it is actually performed.
				if (error.IsntError())
				{
					mPPPagingIoResource		= pSlot;
					mPOldPagingIoResource	= pOldResource;

					SafeReplaceResource(*mPPPagingIoResource, 
						&mNewPagingIoResource);
				}
			}
		}

		ObDereferenceObject(fileObject);
	}

	return error;
}

void 
CFileImpDrvNT::RestorePagingIoResource()
{
	// Nothing to undo if ReplacePagingIoResource did not swap anything in.
	if (mPOldPagingIoResource == NULL)
		return;

	// Give the FCB its original resource back, then destroy ours.
	SafeReplaceResource(*mPPPagingIoResource, mPOldPagingIoResource);
	ExDeleteResourceLite(&mNewPagingIoResource);

	// Forget the FCB slot and the saved resource.
	mPOldPagingIoResource	= NULL;
	mPPPagingIoResource		= NULL;
}

// Replace the resource pointer referenced by 'pOldResource' with
// 'pNewResource', but only at a moment when the current resource has no
// exclusive or shared waiters.
//
// Each attempt does the following:
//   1. Acquires the current resource exclusively.
//   2. Raises to DISPATCH_LEVEL and checks the waiter counts.
//   3. Swaps the pointer if there are no waiters.
//   4. Releases the resource it acquired.
// If waiters were present, it sleeps kReplaceResourceWaitIntervalMs and
// retries. Must be called at IRQL < DISPATCH_LEVEL.
void 
CFileImpDrvNT::SafeReplaceResource(
	PERESOURCE&	pOldResource, 
	PERESOURCE	pNewResource)
{
	pgpAssertAddrValid(pOldResource, ERESOURCE);
	pgpAssertAddrValid(pNewResource, ERESOURCE);

	// A negative interval means a relative wait, in 100ns units.
	LARGE_INTEGER	shortWait	= RtlConvertLongToLargeInteger(0 - 
		static_cast<PGPInt32>(kReplaceResourceWaitIntervalMs * 
		PFLConstants::kHundredNsPerMs));

	PGPBoolean	replaced	= FALSE;

	while (!replaced)
	{
		// ExAcquireResourceExclusiveLite requires normal kernel APCs to be
		// disabled. Otherwise this thread could be suspended while it owns
		// the FCB's paging resource exclusively, blocking every other user
		// of that resource.
		KeEnterCriticalRegion();

		// Acquire the current resource exclusive.
		ExAcquireResourceExclusiveLite(pOldResource, TRUE);

		// Go to dispatch level before checking for waiters.
		KIRQL	oldIrql;
		KeRaiseIrql(DISPATCH_LEVEL, &oldIrql);

		// Remember which resource we own; the slot may be rewritten below.
		PERESOURCE	ownedResource	= pOldResource;

		if ((ExGetExclusiveWaiterCount(ownedResource) == 0) && 
			(ExGetSharedWaiterCount(ownedResource) == 0))
		{
			// There are no waiters, perform the replace.
			pOldResource	= pNewResource;
			replaced		= TRUE;
		}

		ExReleaseResourceForThreadLite(ownedResource, 
			ExGetCurrentResourceThread());

		KeLowerIrql(oldIrql);

		// KeLeaveCriticalRegion must run at IRQL <= APC_LEVEL, hence after
		// KeLowerIrql.
		KeLeaveCriticalRegion();

		// There were waiters: give them a chance to finish, then retry.
		if (!replaced)
			KeDelayExecutionThread(KernelMode, TRUE, &shortWait);
	}
}

PGPBoolean 
CFileImpDrvNT::IsUNCPath(const char *path) const
{
	pgpAssertStrValid(path);

	// A UNC path starts with two slash characters...
	if (!IsSlashChar(path[0]) || !IsSlashChar(path[1]))
		return FALSE;

	// ...and has at least one more slash somewhere after them
	// (e.g. \\server\share).
	for (const char *cursor = path + 2; *cursor != '\0'; ++cursor)
	{
		if (IsSlashChar(*cursor))
			return TRUE;
	}

	return FALSE;
}
